import torch
import torch.nn as nn
import torch.nn.functional as F

# External library imports
from src.external_libs.limoe.main import LiMoE
from src.external_libs.RealNVP.realnvp import RealNVPImageFlow, RealNVPSignalFlow

# --- Positional Encoding ---
def get_positional_encoding(seq_len, dim, device='cpu'):
    """Return a ``(seq_len, dim)`` sinusoidal positional-encoding table.

    Channel pairs ``(2k, 2k + 1)`` share the frequency
    ``1 / 10000 ** (2k / dim)``; even channels hold the sine of
    ``position * frequency`` and odd channels hold the cosine.
    An odd ``dim`` simply leaves one extra sine channel at the end.
    """
    positions = torch.arange(seq_len, dtype=torch.float32, device=device)[:, None]
    channels = torch.arange(dim, dtype=torch.float32, device=device)[None, :]

    # Index of the (sin, cos) pair each channel belongs to.
    pair_index = channels // 2
    frequencies = 1 / torch.pow(10000, (2 * pair_index) / dim)
    phases = positions * frequencies

    table = torch.zeros(seq_len, dim, device=device)
    table[:, ::2] = phases[:, ::2].sin()
    table[:, 1::2] = phases[:, 1::2].cos()
    return table

class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping ``in_dim`` to ``out_dim``.

    The hidden layer keeps the input width. Layers live in ``self.net`` as
    indices 0 and 2 so checkpoint keys are ``net.0.*`` and ``net.2.*``.
    """

    def __init__(self, in_dim=768, out_dim=128):
        super().__init__()
        hidden_dim = in_dim
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected

class MADModel(nn.Module):
    """Multimodal classifier over image-patch and signal-patch token sequences.

    Each enabled modality is linearly projected to ``embed_dim`` and offset by a
    sinusoidal positional encoding. It is then passed through a RealNVP flow, or
    an ``nn.Identity`` placeholder when that flow has zero layers. The
    per-modality token sequences are concatenated along dim 1, encoded by a
    shared LiMoE encoder and mean-pooled over dim 1. The pooled vector goes
    either to the SupCon projection head or to the linear classification head.

    A modality is enabled only when both its patch count and its per-patch
    feature size are positive. At least one modality must be enabled.

    Args:
        num_classes: Number of logits produced by the classification head.
        img_patch_flat_dim: Per-patch input size of ``img_projection``.
        sig_patch_dim: Per-patch input size of ``sig_projection``.
        num_img_patches: Number of image tokens; also the positional-encoding length.
        num_sig_patches: Number of signal tokens; also the positional-encoding length.
        embed_dim: Token embedding size shared by all submodules.
        limoe_depth, limoe_heads, limoe_num_experts, limoe_ff_mult,
        limoe_top_k, limoe_dropout: Forwarded to the LiMoE encoder.
        limoe_dim_head: Per-head size for LiMoE. ``None`` (or any falsy value)
            falls back to ``embed_dim // limoe_heads``.
        realnvp_img_layers, realnvp_sig_layers: Number of flow layers per
            modality. ``0`` replaces the flow with ``nn.Identity``.
        proj_head_out_dim: Output size of the SupCon projection head.

    Raises:
        ValueError: If neither modality is enabled.
    """

    def __init__(self,
                 num_classes: int,
                 img_patch_flat_dim: int,
                 sig_patch_dim: int,
                 num_img_patches: int,
                 num_sig_patches: int,
                 embed_dim: int = 768,
                 # LiMoE parameters
                 limoe_depth: int = 6,
                 limoe_heads: int = 8,
                 limoe_dim_head: int | None = None,
                 limoe_num_experts: int = 4,
                 limoe_ff_mult: int = 2,
                 limoe_top_k: int = 2,
                 limoe_dropout: float = 0.1,
                 # RealNVP parameters
                 realnvp_img_layers: int = 6,
                 realnvp_sig_layers: int = 8,
                 # SupCon Projection Head
                 proj_head_out_dim: int = 128):
        super().__init__()

        self.use_img_modality = num_img_patches > 0 and img_patch_flat_dim > 0
        self.use_sig_modality = num_sig_patches > 0 and sig_patch_dim > 0

        if not self.use_img_modality and not self.use_sig_modality:
            raise ValueError("At least one modality (image or signal) must be used.")

        self.num_img_patches = num_img_patches if self.use_img_modality else 0
        self.num_sig_patches = num_sig_patches if self.use_sig_modality else 0
        self.embed_dim = embed_dim

        # --- Conditionally create per-modality projection + flow modules ---
        # Registration order (img before sig, then encoder/heads) is kept stable
        # so parameter initialisation order and state_dict keys do not change.
        if self.use_img_modality:
            self.img_projection = nn.Linear(img_patch_flat_dim, embed_dim)
            if realnvp_img_layers > 0:
                self.flow_img = RealNVPImageFlow(dim=embed_dim, num_layers=realnvp_img_layers)
            else:
                self.flow_img = nn.Identity()  # Placeholder when RealNVP is disabled

        if self.use_sig_modality:
            self.sig_projection = nn.Linear(sig_patch_dim, embed_dim)
            if realnvp_sig_layers > 0:
                self.flow_sig = RealNVPSignalFlow(dim=embed_dim, num_layers=realnvp_sig_layers)
            else:
                self.flow_sig = nn.Identity()  # Placeholder when RealNVP is disabled

        encoder_total_seq_length = self.num_img_patches + self.num_sig_patches

        self.encoder = LiMoE(
            dim=embed_dim,
            depth=limoe_depth,
            heads=limoe_heads,
            dim_head=limoe_dim_head or embed_dim // limoe_heads,
            num_experts=limoe_num_experts,
            ff_mult=limoe_ff_mult,
            dropout=limoe_dropout,
            top_k_experts=limoe_top_k,
            seq_length=encoder_total_seq_length,
            # Some LiMoE arguments are kept for library compatibility, though they may not be used
            num_tokens=10000, patch_size=16, image_size=256, channels=3, dense_encoder_depth=1
        )

        self.proj_head_supcon = ProjectionHead(in_dim=embed_dim, out_dim=proj_head_out_dim)
        self.classification_head = nn.Linear(embed_dim, num_classes)

    def _embed_modality(self, x: torch.Tensor, projection: nn.Module,
                        flow: nn.Module, num_patches: int) -> torch.Tensor:
        """Project one modality's patches, add positional encoding, apply its flow."""
        projected = projection(x)
        pe = get_positional_encoding(num_patches, self.embed_dim, device=projected.device)
        # The table is built in float32. Cast it to the activation dtype so a
        # model converted with .half()/.to(torch.bfloat16) keeps one dtype.
        # Otherwise type promotion would upcast the tokens to float32, and the
        # reduced-precision flow/encoder weights would reject them.
        tokens = projected + pe.to(dtype=projected.dtype).unsqueeze(0)
        return flow(tokens)  # Identity when the modality's RealNVP has 0 layers

    def forward(self, x_img: torch.Tensor | None, x_sig: torch.Tensor | None, for_supcon: bool = False):
        """Encode the available modalities and return logits or SupCon embeddings.

        Args:
            x_img: Image patches, expected as ``(batch, num_img_patches, img_patch_flat_dim)``
                (the positional table is broadcast over dim 0). Ignored when
                ``None`` or when the image modality is disabled.
            x_sig: Signal patches, expected as ``(batch, num_sig_patches, sig_patch_dim)``.
                Ignored when ``None`` or when the signal modality is disabled.
            for_supcon: If True, return the projection-head output L2-normalised
                along dim 1. Otherwise return classification logits.

        Raises:
            ValueError: If no modality contributes tokens.
        """
        all_tokens = []

        if self.use_img_modality and x_img is not None:
            all_tokens.append(
                self._embed_modality(x_img, self.img_projection, self.flow_img, self.num_img_patches)
            )

        if self.use_sig_modality and x_sig is not None:
            all_tokens.append(
                self._embed_modality(x_sig, self.sig_projection, self.flow_sig, self.num_sig_patches)
            )

        if not all_tokens:
            raise ValueError("No valid input tokens to process. Both inputs might be None.")

        # Fuse tokens along the sequence axis and pass through the shared encoder
        z_all = torch.cat(all_tokens, dim=1)
        # Use the `precomputed_embeddings` argument to match the LiMoE library's input format
        encoded_sequence = self.encoder(precomputed_embeddings=z_all)

        # Mean-pool over dim 1; assumes LiMoE returns (batch, seq, embed_dim) - TODO confirm
        pooled_output = torch.mean(encoded_sequence, dim=1)

        if for_supcon:
            # For pre-training, return normalized contrastive vectors
            return F.normalize(self.proj_head_supcon(pooled_output), dim=1)
        # For fine-tuning and inference, return classification logits
        return self.classification_head(pooled_output)

    @staticmethod
    def get_param_keys():
        """Return the names of MADModel.__init__'s parameters (excluding self), in order.

        The list is derived from the constructor signature, so it cannot drift
        out of sync when constructor arguments are added or renamed.
        """
        import inspect

        params = inspect.signature(MADModel.__init__).parameters
        return [name for name in params if name != 'self']